{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Natural language inference: models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Christopher Potts\"\n",
    "__version__ = \"CS224u, Stanford, Fall 2020\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Overview](#Overview)\n",
    "1. [Set-up](#Set-up)\n",
    "1. [Sparse feature representations](#Sparse-feature-representations)\n",
    "  1. [Feature representations](#Feature-representations)\n",
    "  1. [Model wrapper for hyperparameter search](#Model-wrapper-for-hyperparameter-search)\n",
    "  1. [Assessment](#Assessment)\n",
    "1. [Hypothesis-only baselines](#Hypothesis-only-baselines)\n",
    "1. [Sentence-encoding models](#Sentence-encoding-models)\n",
    "  1. [Dense representations](#Dense-representations)\n",
    "  1. [Sentence-encoding RNNs](#Sentence-encoding-RNNs)\n",
    "  1. [Other sentence-encoding model ideas](#Other-sentence-encoding-model-ideas)\n",
    "1. [Chained models](#Chained-models)\n",
    "  1. [Simple RNN](#Simple-RNN)\n",
    "  1. [Separate premise and hypothesis RNNs](#Separate-premise-and-hypothesis-RNNs)\n",
    "1. [Attention mechanisms](#Attention-mechanisms)\n",
    "1. [Error analysis with the MultiNLI annotations](#Error-analysis-with-the-MultiNLI-annotations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "This notebook defines and explores a number of models for NLI. The general plot is familiar from [our work with the Stanford Sentiment Treebank](sst_01_overview.ipynb):\n",
    "\n",
    "1. Models based on sparse feature representations\n",
    "1. Linear classifiers and feed-forward neural classifiers using dense feature representations\n",
    "1. Recurrent neural networks (and, briefly, tree-structured neural networks)\n",
    "\n",
    "The twist here is that, while NLI is another classification problem, the inputs have important high-level structure: __a premise__ and __a hypothesis__. This invites exploration of a host of neural designs:\n",
    "\n",
    "* In __sentence-encoding__ models, the premise and hypothesis are analyzed separately, and combined only for the final classification step.\n",
    "\n",
    "* In __chained__ models, the premise is processed first, then the hypothesis, giving a unified representation of the pair.\n",
    "\n",
    "NLI resembles sequence-to-sequence problems like __machine translation__ and __language modeling__. The central modeling difference is that NLI doesn't produce an output sequence, but rather consumes two sequences to produce a label. Still, there are enough affinities that many ideas have been shared among these areas."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set-up\n",
    "\n",
    "See [the previous notebook](nli_01_task_and_data.ipynb#Set-up) for set-up instructions for this unit. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from itertools import product\n",
    "import nli\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "from sklearn.exceptions import ConvergenceWarning\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data\n",
    "from torch_model_base import TorchModelBase\n",
    "from torch_rnn_classifier import TorchRNNClassifier, TorchRNNModel\n",
    "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
    "import utils\n",
    "import warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "utils.fix_random_seeds()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "GLOVE_HOME = os.path.join('data', 'glove.6B')\n",
    "\n",
    "DATA_HOME = os.path.join(\"data\", \"nlidata\")\n",
    "\n",
    "SNLI_HOME = os.path.join(DATA_HOME, \"snli_1.0\")\n",
    "\n",
    "MULTINLI_HOME = os.path.join(DATA_HOME, \"multinli_1.0\")\n",
    "\n",
    "ANNOTATIONS_HOME = os.path.join(DATA_HOME, \"multinli_1.0_annotations\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sparse feature representations\n",
    "\n",
    "We begin by looking at models based on sparse, hand-built feature representations. As in earlier units of the course, we will see that __these models are competitive__: easy to design, fast to optimize, and highly effective."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feature representations\n",
    "\n",
    "The guiding idea for NLI sparse features is that one wants to knit together the premise and hypothesis, so that the model can learn about their relationships rather than just about each part separately."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `word_overlap_phi`, we just get the set of words that occur in both the premise and hypothesis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def word_overlap_phi(t1, t2):\n",
    "    \"\"\"\n",
    "    Basis for features for the words in both the premise and hypothesis.\n",
    "    Downcases all words.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    t1, t2 : `nltk.tree.Tree`\n",
    "        As given by `str2tree`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    collections.Counter\n",
    "       Maps each word in both `t1` and `t2` to 1.\n",
    "\n",
    "    \"\"\"\n",
    "    words1 = {w.lower() for w in t1.leaves()}\n",
    "    words2 = {w.lower() for w in t2.leaves()}\n",
    "    return Counter(words1 & words2)"
   ]
  },
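  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a quick check of what `word_overlap_phi` returns on a toy pair. To keep the sketch runnable without `nltk`, it uses a hypothetical `LeafTree` class that exposes only the `leaves()` method the feature function relies on. In this notebook, the real inputs are `nltk.tree.Tree` objects.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "class LeafTree:\n",
    "    # Hypothetical stand-in for nltk.tree.Tree: only leaves() is needed here.\n",
    "    def __init__(self, words):\n",
    "        self._words = list(words)\n",
    "\n",
    "    def leaves(self):\n",
    "        return list(self._words)\n",
    "\n",
    "\n",
    "def word_overlap_phi(t1, t2):\n",
    "    # Same logic as the notebook function: downcased words found in both trees.\n",
    "    words1 = {w.lower() for w in t1.leaves()}\n",
    "    words2 = {w.lower() for w in t2.leaves()}\n",
    "    return Counter(words1 & words2)\n",
    "\n",
    "\n",
    "premise = LeafTree('A dog runs in the park'.split())\n",
    "hypothesis = LeafTree('The dog is outside'.split())\n",
    "feats = word_overlap_phi(premise, hypothesis)\n",
    "print(sorted(feats.items()))  # [('dog', 1), ('the', 1)]\n",
    "```\n",
    "\n",
    "The features come from a set intersection, so each shared word maps to 1 however often it occurs. Downcasing lets `The` in the hypothesis match `the` in the premise."
   ]
  },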
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `word_cross_product_phi`, we count all the pairs $(w_{1}, w_{2})$ where $w_{1}$ is a word from the premise and $w_{2}$ is a word from the hypothesis. This creates a very large feature space. These models are very strong right out of the box, and they can be supplemented with more fine-grained features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def word_cross_product_phi(t1, t2):\n",
    "    \"\"\"\n",
    "    Basis for cross-product features. Downcases all words.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    t1, t2 : `nltk.tree.Tree`\n",
    "        As given by `str2tree`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    collections.Counter\n",
    "        Maps each (w1, w2) in the cross-product of `t1.leaves()` and\n",
    "        `t2.leaves()` (both downcased) to its count. This is a\n",
    "        multi-set cross-product (repetitions matter).\n",
    "\n",
    "    \"\"\"\n",
    "    words1 = [w.lower() for w in t1.leaves()]\n",
    "    words2 = [w.lower() for w in t2.leaves()]\n",
    "    return Counter([(w1, w2) for w1, w2 in product(words1, words2)])"
   ]
  },
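  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below shows how fast this feature space grows. It applies the same cross-product counting to plain token lists through a hypothetical helper, `cross_product_counts`. A 5-word premise and a 3-word hypothesis yield 5 × 3 = 15 pair tokens. A word that repeats contributes repeated counts.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "from itertools import product\n",
    "\n",
    "\n",
    "def cross_product_counts(words1, words2):\n",
    "    # Mirrors word_cross_product_phi, but takes plain token lists.\n",
    "    words1 = [w.lower() for w in words1]\n",
    "    words2 = [w.lower() for w in words2]\n",
    "    return Counter(product(words1, words2))\n",
    "\n",
    "\n",
    "premise = 'the dog saw the cat'.split()\n",
    "hypothesis = 'a dog slept'.split()\n",
    "feats = cross_product_counts(premise, hypothesis)\n",
    "\n",
    "# Every premise/hypothesis token pair is counted: 5 * 3 = 15 in total.\n",
    "print(sum(feats.values()))    # 15\n",
    "# 'the' occurs twice in the premise, so each of its pairs is counted twice.\n",
    "print(feats[('the', 'dog')])  # 2\n",
    "# 4 distinct premise words * 3 distinct hypothesis words:\n",
    "print(len(feats))             # 12\n",
    "```"
   ]
  },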
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model wrapper for hyperparameter search\n",
    "\n",
    "Our experiment framework is basically the same as the one we used for the Stanford Sentiment Treebank. \n",
    "\n",
    "For a full evaluation, we would like to search for the best hyperparameters. However, SNLI is very large, so each evaluation is very expensive. To keep this under control, we can set the optimizer to do just a few iterations of training during the search phase. The assumption here is that the hyperparameters that are best after full training already look best early in the process. This is by no means guaranteed, but it seems like a good way to balance doing serious hyperparameter search with the costs of doing dozens or even thousands of experiments. (See also [the discussion of hyperparameter search in the evaluation methods notebook](evaluation_methods.ipynb#Hyperparameter-optimization).)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_softmax_with_hyperparameter_search(X, y):\n",
    "    \"\"\"\n",
    "    Fit a MaxEnt model to the dataset, selecting hyperparameters by cross-validation.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : 2d np.array\n",
    "        The matrix of features, one example per row.\n",
    "\n",
    "    y : list\n",
    "        The list of labels for rows in `X`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    sklearn.linear_model.LogisticRegression\n",
    "        A trained model instance, the best model found.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    mod = LogisticRegression(\n",
    "        fit_intercept=True,\n",
    "        max_iter=3,  ## A small number of iterations.\n",
    "        solver='liblinear',\n",
    "        multi_class='ovr')\n",
    "\n",
    "    param_grid = {\n",
    "        'C': [0.4, 0.6, 0.8, 1.0],\n",
    "        'penalty': ['l1','l2']}\n",
    "\n",
    "    with warnings.catch_warnings():\n",
    "        warnings.simplefilter(\"ignore\")\n",
    "        bestmod = utils.fit_classifier_with_hyperparameter_search(\n",
    "            X, y, mod, param_grid=param_grid, cv=3)\n",
    "\n",
    "    return bestmod"
   ]
  },
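  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below gives a rough sense of the search cost. It enumerates the grid used above and counts model fits. It assumes that `utils.fit_classifier_with_hyperparameter_search` behaves like a standard grid search, fitting every setting once per cross-validation fold. Any final refit of the best setting is not counted. `grid_settings` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "# The grid and fold count from fit_softmax_with_hyperparameter_search.\n",
    "param_grid = {\n",
    "    'C': [0.4, 0.6, 0.8, 1.0],\n",
    "    'penalty': ['l1', 'l2']}\n",
    "cv = 3\n",
    "\n",
    "\n",
    "def grid_settings(grid):\n",
    "    # Enumerate every combination of hyperparameter values as a dict.\n",
    "    keys = sorted(grid)\n",
    "    return [dict(zip(keys, values))\n",
    "            for values in product(*(grid[k] for k in keys))]\n",
    "\n",
    "\n",
    "settings = grid_settings(param_grid)\n",
    "n_fits = len(settings) * cv  # one fit per setting per fold\n",
    "print(len(settings), n_fits)  # 8 24\n",
    "```\n",
    "\n",
    "Even with `max_iter=3`, that is 24 fits on large subsets of SNLI, which is why the search phase keeps each fit short."
   ]
  },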
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Assessment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best params: {'C': 0.4, 'penalty': 'l1'}\n",
      "Best score: 0.704\n",
      "CPU times: user 18min 40s, sys: 7min 11s, total: 25min 52s\n",
      "Wall time: 10min 35s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "word_cross_product_experiment_xval = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=word_cross_product_phi,\n",
    "    train_func=fit_softmax_with_hyperparameter_search,\n",
    "    assess_reader=None,\n",
    "    verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_word_cross_product_model = word_cross_product_experiment_xval['model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `word_cross_product_experiment_xval` consumes a lot of memory, and we\n",
    "# won't make use of it outside of the model, so we can remove it now.\n",
    "del word_cross_product_experiment_xval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_optimized_word_cross_product(X, y):\n",
    "    optimized_word_cross_product_model.max_iter = 1000 # To convergence in this phase!\n",
    "    optimized_word_cross_product_model.fit(X, y)\n",
    "    return optimized_word_cross_product_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.782     0.766     0.774      3278\n",
      "   entailment      0.743     0.811     0.775      3329\n",
      "      neutral      0.725     0.672     0.698      3235\n",
      "\n",
      "     accuracy                          0.750      9842\n",
      "    macro avg      0.750     0.749     0.749      9842\n",
      " weighted avg      0.750     0.750     0.749      9842\n",
      "\n",
      "CPU times: user 6min 19s, sys: 5.17 s, total: 6min 24s\n",
      "Wall time: 6min 22s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=word_cross_product_phi,\n",
    "    train_func=fit_optimized_word_cross_product,\n",
    "    assess_reader=nli.SNLIDevReader(SNLI_HOME))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, `word_cross_product_phi` is reasonably strong. This model is similar to (a simplified version of) the baseline \"Lexicalized Classifier\" in [the original SNLI paper by Bowman et al.](https://www.aclweb.org/anthology/D15-1075/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hypothesis-only baselines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In an outstanding project for this course in 2016, [Leonid Keselman](https://leonidk.com) observed that [one can do much better than chance on SNLI by processing only the hypothesis](https://leonidk.com/stanford/cs224u.html). This relates to [observations we made in the word-level homework/bake-off](hw_wordentail.ipynb) about how certain terms will tend to appear more on the right in entailment pairs than on the left. In 2018, a number of groups independently (re-)discovered this fact and published analyses: [Poliak et al. 2018](https://www.aclweb.org/anthology/S18-2023/), [Tsuchiya 2018](https://www.aclweb.org/anthology/L18-1239/), [Gururangan et al. 2018](https://www.aclweb.org/anthology/N18-2017/). Let's build on this insight by fitting a hypothesis-only model that seems comparable to the cross-product-based model we just looked at:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hypothesis_only_unigrams_phi(t1, t2):\n",
    "    return Counter(t2.leaves())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_softmax(X, y):\n",
    "    mod = LogisticRegression(\n",
    "        fit_intercept=True,\n",
    "        solver='liblinear',\n",
    "        multi_class='ovr')\n",
    "    mod.fit(X, y)\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.654     0.631     0.643      3278\n",
      "   entailment      0.639     0.715     0.675      3329\n",
      "      neutral      0.670     0.613     0.640      3235\n",
      "\n",
      "     accuracy                          0.653      9842\n",
      "    macro avg      0.655     0.653     0.653      9842\n",
      " weighted avg      0.654     0.653     0.653      9842\n",
      "\n",
      "CPU times: user 16min 52s, sys: 13min 55s, total: 30min 48s\n",
      "Wall time: 4min 46s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=hypothesis_only_unigrams_phi,\n",
    "    train_func=fit_softmax,\n",
    "    assess_reader=nli.SNLIDevReader(SNLI_HOME))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Chance performance on SNLI is 0.33 accuracy/F1. The above makes it clear that using chance as a baseline will overstate how much traction a model has actually gotten on the SNLI problem. The hypothesis-only baseline is better for this kind of calibration. \n",
    "\n",
    "Ideally, for each model one explores, one would fit a minimally different hypothesis-only model as a baseline. To avoid undue complexity, I won't do that here, but we will use the above results to provide informal context, and I will sketch reasonable hypothesis-only baselines for each model we consider."
   ]
  },
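  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dev-set support counts in the classification reports above make the chance figure concrete. The sketch computes two baselines: uniform random guessing, and the slightly stronger majority-class baseline. Both fall far below the 0.653 accuracy of the hypothesis-only model.\n",
    "\n",
    "```python\n",
    "# Support counts for the SNLI dev set, copied from the reports above.\n",
    "support = {'contradiction': 3278, 'entailment': 3329, 'neutral': 3235}\n",
    "total = sum(support.values())\n",
    "\n",
    "# Uniform random guessing over three labels: expected accuracy 1/3.\n",
    "uniform_accuracy = 1 / len(support)\n",
    "\n",
    "# Always predicting the most frequent label.\n",
    "majority_label = max(support, key=support.get)\n",
    "majority_accuracy = support[majority_label] / total\n",
    "\n",
    "print(total)                                        # 9842\n",
    "print(round(uniform_accuracy, 3))                   # 0.333\n",
    "print(majority_label, round(majority_accuracy, 3))  # entailment 0.338\n",
    "```"
   ]
  },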
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sentence-encoding models\n",
    "\n",
    "We turn now to sentence-encoding models. The hallmark of these is that the premise and hypothesis get their own representation in some sense, and then those representations are combined to predict the label. [Bowman et al. 2015](http://aclweb.org/anthology/D/D15/D15-1075.pdf) explore models of this form as part of introducing SNLI."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dense representations\n",
    "\n",
    "Perhaps the simplest sentence-encoding model sums (or averages, etc.) the word representations for the premise, does the same for the hypothesis, and concatenates those two representations for use as the input to a linear classifier. \n",
    "\n",
    "Here's a diagram that is meant to suggest the full space of models of this form:\n",
    "\n",
    "<img src=\"fig/nli-softmax.png\" width=800 />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's an implementation of this model where \n",
    "\n",
    "* The embedding is GloVe.\n",
    "* The word representations are averaged (`np.mean`, the default `np_func`; passing `np.sum` would sum them instead).\n",
    "* The premise and hypothesis vectors are concatenated.\n",
    "* A softmax classifier is used at the top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_lookup = utils.glove2dict(\n",
    "    os.path.join(GLOVE_HOME, 'glove.6B.300d.txt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def glove_leaves_phi(t1, t2, np_func=np.mean):\n",
    "    \"\"\"\n",
    "    Represent `t1` and `t2` as a combination of the vectors of their words,\n",
    "    and concatenate these two combined vectors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    t1 : nltk.Tree\n",
    "\n",
    "    t2 : nltk.Tree\n",
    "\n",
    "    np_func : function\n",
    "        A numpy matrix operation that can be applied columnwise,\n",
    "        like `np.mean`, `np.sum`, or `np.prod`. The requirement is that\n",
    "        the function take `axis=0` as one of its arguments (to ensure\n",
    "        columnwise combination) and that it return a vector of a\n",
    "        fixed length, no matter what the size of the tree is.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.array\n",
    "\n",
    "    \"\"\"\n",
    "    prem_vecs = _get_tree_vecs(t1, glove_lookup, np_func)\n",
    "    hyp_vecs = _get_tree_vecs(t2, glove_lookup, np_func)\n",
    "    return np.concatenate((prem_vecs, hyp_vecs))\n",
    "\n",
    "\n",
    "def _get_tree_vecs(tree, lookup, np_func):\n",
    "    allvecs = np.array([lookup[w] for w in tree.leaves() if w in lookup])\n",
    "    if len(allvecs) == 0:\n",
    "        dim = len(next(iter(lookup.values())))\n",
    "        feats = np.zeros(dim)\n",
    "    else:\n",
    "        feats = np_func(allvecs, axis=0)\n",
    "    return feats"
   ]
  },
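  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is the same combination logic run on a hypothetical two-dimensional toy lookup in place of GloVe, so it works without the embedding files. `get_vecs` is a hypothetical list-based copy of `_get_tree_vecs`. Words without a vector are skipped. `np.mean` and `np.sum` produce different vectors, but the output length is always twice the embedding dimension.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def get_vecs(words, lookup, np_func):\n",
    "    # Same logic as _get_tree_vecs, but over a plain list of words.\n",
    "    allvecs = np.array([lookup[w] for w in words if w in lookup])\n",
    "    if len(allvecs) == 0:\n",
    "        # No known words: fall back to a zero vector of the right size.\n",
    "        dim = len(next(iter(lookup.values())))\n",
    "        return np.zeros(dim)\n",
    "    return np_func(allvecs, axis=0)\n",
    "\n",
    "\n",
    "# Hypothetical 2-dimensional embeddings standing in for GloVe.\n",
    "toy_lookup = {\n",
    "    'dog': np.array([1.0, 0.0]),\n",
    "    'runs': np.array([0.0, 2.0]),\n",
    "    'cat': np.array([3.0, 1.0])}\n",
    "\n",
    "prem = ['dog', 'runs']\n",
    "hyp = ['cat', 'unseen']  # 'unseen' has no vector and is skipped\n",
    "\n",
    "feats_mean = np.concatenate((get_vecs(prem, toy_lookup, np.mean),\n",
    "                             get_vecs(hyp, toy_lookup, np.mean)))\n",
    "feats_sum = np.concatenate((get_vecs(prem, toy_lookup, np.sum),\n",
    "                            get_vecs(hyp, toy_lookup, np.sum)))\n",
    "print(feats_mean.tolist())  # [0.5, 1.0, 3.0, 1.0]\n",
    "print(feats_sum.tolist())   # [1.0, 2.0, 3.0, 1.0]\n",
    "```"
   ]
  },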
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best params: {'C': 1.0, 'penalty': 'l1'}\n",
      "Best score: 0.550\n",
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.549     0.531     0.540      3278\n",
      "   entailment      0.547     0.565     0.556      3329\n",
      "      neutral      0.570     0.571     0.570      3235\n",
      "\n",
      "     accuracy                          0.555      9842\n",
      "    macro avg      0.556     0.555     0.555      9842\n",
      " weighted avg      0.555     0.555     0.555      9842\n",
      "\n",
      "CPU times: user 11min 39s, sys: 1min 16s, total: 12min 56s\n",
      "Wall time: 12min 12s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=glove_leaves_phi,\n",
    "    train_func=fit_softmax_with_hyperparameter_search,\n",
    "    assess_reader=nli.SNLIDevReader(SNLI_HOME),\n",
    "    vectorize=False)  # Ask `experiment` not to featurize; we did it already."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hypothesis-only counterpart of this model is very clear: we would just encode `t2` with GloVe, leaving `t1` out entirely."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an elaboration of this approach, it is worth considering the `VecAvg` model we studied in [sst_03_neural_networks.ipynb](sst_03_neural_networks.ipynb#The-VecAvg-baseline-from-Socher-et-al.-2013), which updates the initial vector representations during learning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sentence-encoding RNNs\n",
    "\n",
    "A more sophisticated sentence-encoding model processes the premise and hypothesis with separate RNNs and uses the concatenation of their final states as the basis for the classification decision at the top:\n",
    "\n",
    "<img src=\"fig/nli-rnn-sentencerep.png\" width=800 />"
   ]
  },
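  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following untrained forward pass of this architecture, written in plain NumPy, makes the data flow concrete. The toy sizes and vocabulary are hypothetical. The sketch uses simple tanh RNN cells rather than the LSTM/GRU cells that `TorchRNNModel` can use. It shares one embedding table between the two encoders and gives each encoder its own recurrent weights. The two sentences interact only at the final classification layer.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "vocab = {'a': 0, 'dog': 1, 'runs': 2, 'sleeps': 3}  # hypothetical toy vocab\n",
    "embed_dim, hidden_dim, n_classes = 4, 3, 3\n",
    "\n",
    "# One embedding table shared by both encoders.\n",
    "E = rng.normal(size=(len(vocab), embed_dim))\n",
    "\n",
    "\n",
    "def make_rnn_params():\n",
    "    # Separate recurrent weights for each encoder.\n",
    "    return {'W_xh': rng.normal(size=(embed_dim, hidden_dim)) * 0.1,\n",
    "            'W_hh': rng.normal(size=(hidden_dim, hidden_dim)) * 0.1,\n",
    "            'b': np.zeros(hidden_dim)}\n",
    "\n",
    "\n",
    "prem_params = make_rnn_params()\n",
    "hyp_params = make_rnn_params()\n",
    "W_out = rng.normal(size=(2 * hidden_dim, n_classes)) * 0.1\n",
    "b_out = np.zeros(n_classes)\n",
    "\n",
    "\n",
    "def encode(tokens, params):\n",
    "    # Simple (Elman) RNN: h_t = tanh(x_t W_xh + h_{t-1} W_hh + b).\n",
    "    h = np.zeros(hidden_dim)\n",
    "    for tok in tokens:\n",
    "        x = E[vocab[tok]]\n",
    "        h = np.tanh(x @ params['W_xh'] + h @ params['W_hh'] + params['b'])\n",
    "    return h  # the final state represents the whole sentence\n",
    "\n",
    "\n",
    "def predict_probs(premise, hypothesis):\n",
    "    prem_state = encode(premise, prem_params)\n",
    "    hyp_state = encode(hypothesis, hyp_params)\n",
    "    # Combine only at the top: concatenate the two final states.\n",
    "    state = np.concatenate((prem_state, hyp_state))\n",
    "    logits = state @ W_out + b_out\n",
    "    exp = np.exp(logits - logits.max())  # numerically stable softmax\n",
    "    return exp / exp.sum()\n",
    "\n",
    "\n",
    "probs = predict_probs(['a', 'dog', 'runs'], ['a', 'dog', 'sleeps'])\n",
    "print(probs.shape)  # (3,)\n",
    "```"
   ]
  },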
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is relatively straightforward to extend `torch_rnn_classifier` so that it can handle this architecture:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### A sentence-encoding dataset\n",
    "\n",
    "Whereas `torch_rnn_classifier.TorchRNNDataset` creates batches that consist of `(sequence, sequence_length, label)` triples, the sentence-encoding model requires us to double the first two components. The most important feature of this class is `collate_fn`, which determines what the batches look like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchRNNSentenceEncoderDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, prem_seqs, hyp_seqs, prem_lengths, hyp_lengths, y=None):\n",
    "        self.prem_seqs = prem_seqs\n",
    "        self.hyp_seqs = hyp_seqs\n",
    "        self.prem_lengths = prem_lengths\n",
    "        self.hyp_lengths = hyp_lengths\n",
    "        self.y = y\n",
    "        assert len(self.prem_seqs) == len(self.hyp_seqs)\n",
    "        assert len(self.hyp_seqs) == len(self.prem_lengths)\n",
    "        assert len(self.prem_lengths) == len(self.hyp_lengths)\n",
    "        if self.y is not None:\n",
    "            assert len(self.hyp_lengths) == len(self.y)\n",
    "\n",
    "    @staticmethod\n",
    "    def collate_fn(batch):\n",
    "        batch = list(zip(*batch))\n",
    "        X_prem = torch.nn.utils.rnn.pad_sequence(batch[0], batch_first=True)\n",
    "        X_hyp = torch.nn.utils.rnn.pad_sequence(batch[1], batch_first=True)\n",
    "        prem_lengths = torch.tensor(batch[2])\n",
    "        hyp_lengths = torch.tensor(batch[3])\n",
    "        if len(batch) == 5:\n",
    "            y = torch.tensor(batch[4])\n",
    "            return X_prem, X_hyp, prem_lengths, hyp_lengths, y\n",
    "        else:\n",
    "            return X_prem, X_hyp, prem_lengths, hyp_lengths\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.prem_seqs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if self.y is None:\n",
    "            return (self.prem_seqs[idx], self.hyp_seqs[idx],\n",
    "                    self.prem_lengths[idx], self.hyp_lengths[idx])\n",
    "        else:\n",
    "            return (self.prem_seqs[idx], self.hyp_seqs[idx],\n",
    "                    self.prem_lengths[idx], self.hyp_lengths[idx],\n",
    "                    self.y[idx])"
   ]
  },
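  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below mimics what `collate_fn` does to a batch. It uses plain lists instead of tensors so that it runs without PyTorch, and `pad_batch` and `collate` are hypothetical names. Premises and hypotheses are padded independently, each to its own maximum length within the batch. Their true lengths are carried along separately.\n",
    "\n",
    "```python\n",
    "def pad_batch(seqs, pad_value=0):\n",
    "    # Plain-Python analogue of pad_sequence(..., batch_first=True):\n",
    "    # right-pad every sequence to the length of the longest one.\n",
    "    max_len = max(len(s) for s in seqs)\n",
    "    return [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]\n",
    "\n",
    "\n",
    "def collate(batch):\n",
    "    # Each example is (prem_seq, hyp_seq, prem_len, hyp_len[, label]).\n",
    "    columns = list(zip(*batch))\n",
    "    X_prem = pad_batch(columns[0])\n",
    "    X_hyp = pad_batch(columns[1])\n",
    "    prem_lengths, hyp_lengths = list(columns[2]), list(columns[3])\n",
    "    if len(columns) == 5:\n",
    "        return X_prem, X_hyp, prem_lengths, hyp_lengths, list(columns[4])\n",
    "    return X_prem, X_hyp, prem_lengths, hyp_lengths\n",
    "\n",
    "\n",
    "batch = [([4, 7, 2], [5, 1], 3, 2, 0),\n",
    "         ([3], [8, 8, 9, 6], 1, 4, 2)]\n",
    "X_prem, X_hyp, prem_lengths, hyp_lengths, y = collate(batch)\n",
    "print(X_prem)  # [[4, 7, 2], [3, 0, 0]]\n",
    "print(X_hyp)   # [[5, 1, 0, 0], [8, 8, 9, 6]]\n",
    "```"
   ]
  },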
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### A sentence-encoding model\n",
    "\n",
    "With `TorchRNNSentenceEncoderClassifierModel`, we create a new `nn.Module` that functions just like the existing `torch_rnn_classifier.TorchRNNClassifierModel`, except that it takes two RNN instances as arguments and combines their final output states to create the classifier input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchRNNSentenceEncoderClassifierModel(nn.Module):\n",
    "    def __init__(self, prem_rnn, hyp_rnn, output_dim):\n",
    "        super().__init__()\n",
    "        self.prem_rnn = prem_rnn\n",
    "        self.hyp_rnn = hyp_rnn\n",
    "        self.output_dim = output_dim\n",
    "        self.bidirectional = self.prem_rnn.bidirectional\n",
    "        # Doubled because we concatenate the final states of\n",
    "        # the premise and hypothesis RNNs:\n",
    "        self.classifier_dim = self.prem_rnn.hidden_dim * 2\n",
    "        # Bidirectionality doubles it again:\n",
    "        if self.bidirectional:\n",
    "            self.classifier_dim *= 2\n",
    "        self.classifier_layer = nn.Linear(\n",
    "            self.classifier_dim, self.output_dim)\n",
    "\n",
    "    def forward(self, X_prem, X_hyp, prem_lengths, hyp_lengths):\n",
    "        # Premise:\n",
    "        _, prem_state = self.prem_rnn(X_prem, prem_lengths)\n",
    "        prem_state = self.get_batch_final_states(prem_state)\n",
    "        # Hypothesis:\n",
    "        _, hyp_state = self.hyp_rnn(X_hyp, hyp_lengths)\n",
    "        hyp_state = self.get_batch_final_states(hyp_state)\n",
    "        # Final combination:\n",
    "        state = torch.cat((prem_state, hyp_state), dim=1)\n",
    "        # Classifier layer:\n",
    "        logits = self.classifier_layer(state)\n",
    "        return logits\n",
    "\n",
    "    def get_batch_final_states(self, state):\n",
    "        if self.prem_rnn.rnn.__class__.__name__ == 'LSTM':\n",
    "            state = state[0].squeeze(0)\n",
    "        else:\n",
    "            state = state.squeeze(0)\n",
    "        if self.bidirectional:\n",
    "            state = torch.cat((state[0], state[1]), dim=1)\n",
    "        return state"
   ]
  },
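  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape bookkeeping in `get_batch_final_states` and `classifier_dim` is easy to get wrong. Here it is traced with NumPy arrays standing in for tensors. The sizes are hypothetical: a batch of 5 and `hidden_dim=300`, the value used in the SNLI run below. The trace assumes a single-layer bidirectional RNN, as the `squeeze(0)` logic presupposes. For an LSTM, this is the hidden-state half of the returned `(h, c)` pair.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "batch_size, hidden_dim = 5, 300\n",
    "\n",
    "# Final hidden states of a single-layer bidirectional RNN have shape\n",
    "# (num_directions, batch_size, hidden_dim) with num_directions = 2.\n",
    "h = np.zeros((2, batch_size, hidden_dim))\n",
    "\n",
    "# In PyTorch, squeeze(0) leaves this unchanged (the first dim has size 2),\n",
    "# so the two directions are then joined along the feature axis:\n",
    "state = np.concatenate((h[0], h[1]), axis=1)\n",
    "print(state.shape)  # (5, 600)\n",
    "\n",
    "# forward() then concatenates the premise and hypothesis states:\n",
    "combined = np.concatenate((state, state), axis=1)\n",
    "print(combined.shape)  # (5, 1200)\n",
    "\n",
    "# This matches classifier_dim = hidden_dim * 2 (two sentences) * 2 (two directions).\n",
    "assert combined.shape[1] == hidden_dim * 2 * 2\n",
    "```"
   ]
  },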
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### A sentence-encoding model interface\n",
    "\n",
    "Finally, we subclass `TorchRNNClassifier`. Here, we just need to redefine two methods, `build_dataset` and `build_graph`, to make use of the new components above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TorchRNNSentenceEncoderClassifier(TorchRNNClassifier):\n",
    "\n",
    "    def build_dataset(self, X, y=None):\n",
    "        X_prem, X_hyp = zip(*X)\n",
    "        X_prem, prem_lengths = self._prepare_sequences(X_prem)\n",
    "        X_hyp, hyp_lengths = self._prepare_sequences(X_hyp)\n",
    "        if y is None:\n",
    "            return TorchRNNSentenceEncoderDataset(\n",
    "                X_prem, X_hyp, prem_lengths, hyp_lengths)\n",
    "        else:\n",
    "            self.classes_ = sorted(set(y))\n",
    "            self.n_classes_ = len(self.classes_)\n",
    "            class2index = dict(zip(self.classes_, range(self.n_classes_)))\n",
    "            y = [class2index[label] for label in y]\n",
    "            return TorchRNNSentenceEncoderDataset(\n",
    "                X_prem, X_hyp, prem_lengths, hyp_lengths, y)\n",
    "\n",
    "    def build_graph(self):\n",
    "        prem_rnn = TorchRNNModel(\n",
    "            vocab_size=len(self.vocab),\n",
    "            embedding=self.embedding,\n",
    "            use_embedding=self.use_embedding,\n",
    "            embed_dim=self.embed_dim,\n",
    "            rnn_cell_class=self.rnn_cell_class,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            bidirectional=self.bidirectional,\n",
    "            freeze_embedding=self.freeze_embedding)\n",
    "\n",
    "        hyp_rnn = TorchRNNModel(\n",
    "            vocab_size=len(self.vocab),\n",
    "            embedding=prem_rnn.embedding,  # Same embedding for both RNNs.\n",
    "            use_embedding=self.use_embedding,\n",
    "            embed_dim=self.embed_dim,\n",
    "            rnn_cell_class=self.rnn_cell_class,\n",
    "            hidden_dim=self.hidden_dim,\n",
    "            bidirectional=self.bidirectional,\n",
    "            freeze_embedding=self.freeze_embedding)\n",
    "\n",
    "        model = TorchRNNSentenceEncoderClassifierModel(\n",
    "            prem_rnn, hyp_rnn, output_dim=self.n_classes_)\n",
    "\n",
    "        self.embed_dim = prem_rnn.embed_dim\n",
    "\n",
    "        return model"
   ]
  },
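  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`build_dataset` also fixes the label order. Classes are sorted, so index 0 is whichever label sorts first. The snippet replays those lines on hypothetical SNLI-style labels. The last step maps indices back through `classes_`. It illustrates one natural way to decode predictions and is not code taken from `TorchRNNClassifier`.\n",
    "\n",
    "```python\n",
    "# Same label-indexing steps as build_dataset, on hypothetical labels.\n",
    "y = ['neutral', 'entailment', 'contradiction', 'entailment']\n",
    "\n",
    "classes_ = sorted(set(y))\n",
    "n_classes_ = len(classes_)\n",
    "class2index = dict(zip(classes_, range(n_classes_)))\n",
    "y_index = [class2index[label] for label in y]\n",
    "\n",
    "print(classes_)  # ['contradiction', 'entailment', 'neutral']\n",
    "print(y_index)   # [2, 1, 0, 1]\n",
    "\n",
    "# Mapping predicted indices back to label strings:\n",
    "preds = [classes_[i] for i in [0, 2]]\n",
    "print(preds)     # ['contradiction', 'neutral']\n",
    "```"
   ]
  },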
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Simple example\n",
    "\n",
    "This toy problem illustrates how this works in detail:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_example():\n",
    "    vocab = ['a', 'b', '$UNK']\n",
    "\n",
    "    # Reversals are good, and other pairs are bad:\n",
    "    train = [\n",
    "        [(list('ab'), list('ba')), 'good'],\n",
    "        [(list('aab'), list('baa')), 'good'],\n",
    "        [(list('abb'), list('bba')), 'good'],\n",
    "        [(list('aabb'), list('bbaa')), 'good'],\n",
    "        [(list('ba'), list('ba')), 'bad'],\n",
    "        [(list('baa'), list('baa')), 'bad'],\n",
    "        [(list('bba'), list('bab')), 'bad'],\n",
    "        [(list('bbaa'), list('bbab')), 'bad'],\n",
    "        [(list('aba'), list('bab')), 'bad']]\n",
    "\n",
    "    test = [\n",
    "        [(list('baaa'), list('aabb')), 'bad'],\n",
    "        [(list('abaa'), list('baaa')), 'bad'],\n",
    "        [(list('bbaa'), list('bbaa')), 'bad'],\n",
    "        [(list('aaab'), list('baaa')), 'good'],\n",
    "        [(list('aaabb'), list('bbaaa')), 'good']]\n",
    "\n",
    "    mod = TorchRNNSentenceEncoderClassifier(\n",
    "        vocab,\n",
    "        max_iter=1000,\n",
    "        embed_dim=10,\n",
    "        bidirectional=True,\n",
    "        hidden_dim=10)\n",
    "\n",
    "    X, y = zip(*train)\n",
    "    mod.fit(X, y)\n",
    "\n",
    "    X_test, y_test = zip(*test)\n",
    "    preds = mod.predict(X_test)\n",
    "\n",
    "    print(\"\\nPredictions:\")\n",
    "    for ex, pred, gold in zip(X_test, preds, y_test):\n",
    "        score = \"correct\" if pred == gold else \"incorrect\"\n",
    "        print(\"{0:>6} {1:>6} - predicted: {2:>4}; actual: {3:>4} - {4}\".format(\n",
    "            \"\".join(ex[0]), \"\".join(ex[1]), pred, gold, score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 513. Training loss did not improve more than tol=1e-05. Final error is 0.002758701564744115."
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Predictions:\n",
      "  baaa   aabb - predicted:  bad; actual:  bad - correct\n",
      "  abaa   baaa - predicted:  bad; actual:  bad - correct\n",
      "  bbaa   bbaa - predicted:  bad; actual:  bad - correct\n",
      "  aaab   baaa - predicted: good; actual: good - correct\n",
      " aaabb  bbaaa - predicted: good; actual: good - correct\n"
     ]
    }
   ],
   "source": [
    "simple_example()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Example SNLI run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sentence_encoding_rnn_phi(t1, t2):\n",
    "    \"\"\"Map `t1` and `t2` to a pair of lists of leaf nodes.\"\"\"\n",
    "    return (t1.leaves(), t2.leaves())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
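    "def get_sentence_encoding_vocab(X, n_words=None, mincount=1):\n",
    "    \"\"\"Sorted vocabulary over all premise and hypothesis tokens in `X`.\n",
    "\n",
    "    `X` is a sequence of (premise, hypothesis) pairs of token lists.\n",
    "    If `n_words` is given, only the `n_words` most frequent words are\n",
    "    considered; words occurring fewer than `mincount` times are dropped.\n",
    "    The symbol `$UNK` is always included.\n",
    "    \"\"\"\n",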
    "    wc = Counter([w for pair in X for ex in pair for w in ex])\n",
    "    wc = wc.most_common(n_words) if n_words else wc.items()\n",
    "    if mincount > 1:\n",
    "        wc = {(w, c) for w, c in wc if c >= mincount}\n",
    "    vocab = {w for w, c in wc}\n",
    "    vocab.add(\"$UNK\")\n",
    "    return sorted(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_simple_sentence_encoding_rnn_with_hyperparameter_search(X, y):\n",
    "    vocab = get_sentence_encoding_vocab(X, mincount=2)\n",
    "\n",
    "    mod = TorchRNNSentenceEncoderClassifier(\n",
    "        vocab,\n",
    "        hidden_dim=300,\n",
    "        embed_dim=300,\n",
    "        bidirectional=True,\n",
    "        early_stopping=True,\n",
    "        max_iter=1)\n",
    "\n",
    "    param_grid = {\n",
    "        'batch_size': [32, 64, 128, 256],\n",
    "        'eta': [0.0001, 0.001, 0.01]}\n",
    "\n",
    "    bestmod = utils.fit_classifier_with_hyperparameter_search(\n",
    "        X, y, mod, cv=3, param_grid=param_grid)\n",
    "\n",
    "    return bestmod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 1 of 1; error is 4444.1026921272282"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best params: {'batch_size': 64, 'eta': 0.001}\n",
      "Best score: 0.653\n",
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.672     0.612     0.641     54929\n",
      "   entailment      0.669     0.678     0.673     54958\n",
      "      neutral      0.632     0.680     0.655     54924\n",
      "\n",
      "     accuracy                          0.657    164811\n",
      "    macro avg      0.658     0.657     0.656    164811\n",
      " weighted avg      0.658     0.657     0.656    164811\n",
      "\n",
      "CPU times: user 41min 23s, sys: 15.4 s, total: 41min 38s\n",
      "Wall time: 41min 31s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "sentence_encoder_rnn_experiment_xval = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=sentence_encoding_rnn_phi,\n",
    "    train_func=fit_simple_sentence_encoding_rnn_with_hyperparameter_search,\n",
    "    assess_reader=None,\n",
    "    vectorize=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_sentence_encoding_rnn = sentence_encoder_rnn_experiment_xval['model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove unneeded experimental data:\n",
    "del sentence_encoder_rnn_experiment_xval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_optimized_sentence_encoding_rnn(X, y):\n",
    "    optimized_sentence_encoding_rnn.max_iter = 1000  # Give early_stopping time!\n",
    "    optimized_sentence_encoding_rnn.fit(X, y)\n",
    "    return optimized_sentence_encoding_rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 13. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 1812.9148220475763"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.684     0.651     0.667      3278\n",
      "   entailment      0.696     0.715     0.705      3329\n",
      "      neutral      0.675     0.690     0.682      3235\n",
      "\n",
      "     accuracy                          0.685      9842\n",
      "    macro avg      0.685     0.685     0.685      9842\n",
      " weighted avg      0.685     0.685     0.685      9842\n",
      "\n",
      "CPU times: user 22min 55s, sys: 7.98 s, total: 23min 3s\n",
      "Wall time: 22min 59s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=sentence_encoding_rnn_phi,\n",
    "    train_func=fit_optimized_sentence_encoding_rnn,\n",
    "    assess_reader=nli.SNLIDevReader(SNLI_HOME),\n",
    "    vectorize=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is above our general hypothesis-only baseline ($\\approx$0.65), but it is below the simpler word cross-product model ($\\approx$0.75).\n",
    "\n",
    "A natural hypothesis-only baseline for this model would be a simple `TorchRNNClassifier` that processes only the hypothesis."
   ]
  },
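  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hypothesis-only baseline suggested above needs only a feature function that discards the premise. Here is a minimal sketch. It assumes, as `sentence_encoding_rnn_phi` does, that the premise and hypothesis arrive as trees with a `leaves()` method. The name `hypothesis_only_rnn_phi` is hypothetical and not part of the course code:\n",
    "\n",
    "```python\n",
    "def hypothesis_only_rnn_phi(t1, t2):\n",
    "    \"\"\"Ignore the premise `t1` and map the hypothesis `t2` to its leaves.\n",
    "\n",
    "    Any object with a `leaves()` method works; in this notebook, these\n",
    "    are the parse trees supplied by the `nli` readers.\n",
    "    \"\"\"\n",
    "    return t2.leaves()\n",
    "```\n",
    "\n",
    "Because it returns a single flat list of tokens, this function could be passed as `phi` to `nli.experiment` (with `vectorize=False`). It would need a training function that builds a plain `TorchRNNClassifier`, such as `fit_simple_chained_rnn_with_hyperparameter_search` in the chained-models section."
   ]
  },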
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Other sentence-encoding model ideas\n",
    "\n",
    "Given that [we already explored tree-structured neural networks (TreeNNs)](sst_03_neural_networks.ipynb#Tree-structured-neural-networks), it's natural to consider these as the basis for sentence-encoding NLI models:\n",
    "\n",
    "<img src=\"fig/nli-treenn.png\" width=800 />\n",
    "\n",
    "And this is just the beginning: any model used to represent sentences is presumably a candidate for use in sentence-encoding NLI!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chained models\n",
    "\n",
    "The final major class of NLI designs we look at is the one in which the premise and hypothesis are processed sequentially, as a pair. These models do not produce separate representations of the premise and hypothesis. They bear the strongest resemblance to classic sequence-to-sequence models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple RNN\n",
    "\n",
    "In the simplest version of this model, we just concatenate the premise and hypothesis. The model itself is identical to the one we used for the Stanford Sentiment Treebank:\n",
    "\n",
    "<img src=\"fig/nli-rnn-chained.png\" width=800 />"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To implement this, we can use `TorchRNNClassifier` out of the box. We just need to concatenate the leaves of the premise and hypothesis trees:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_chained_rep_rnn_phi(t1, t2):\n",
    "    \"\"\"Map `t1` and `t2` to a single list of leaf nodes.\n",
    "\n",
    "    A slight variant might insert a designated boundary symbol between\n",
    "    the premise leaves and the hypothesis leaves. Be sure to add it to\n",
    "    the vocab in that case, else it will be $UNK.\n",
    "    \"\"\"\n",
    "    return t1.leaves() + t2.leaves()"
   ]
  },
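  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The docstring above mentions a variant with a boundary symbol between premise and hypothesis. A minimal sketch follows. The symbol `'[SEP]'` and both function names are illustrative assumptions, not part of the course code.\n",
    "\n",
    "The symbol is inserted into every featurized example, so a vocabulary built by counting tokens in those examples will normally pick it up anyway. The helper below also adds it explicitly, which guarantees that it is in the vocabulary and never mapped to `$UNK`:\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "BOUNDARY = '[SEP]'  # hypothetical boundary token; any string absent from the corpus works\n",
    "\n",
    "\n",
    "def chained_rep_rnn_phi_with_boundary(t1, t2):\n",
    "    \"\"\"Premise leaves, then `BOUNDARY`, then hypothesis leaves.\"\"\"\n",
    "    return t1.leaves() + [BOUNDARY] + t2.leaves()\n",
    "\n",
    "\n",
    "def get_chained_vocab(X, mincount=1):\n",
    "    \"\"\"Hypothetical helper: words seen at least `mincount` times in `X`,\n",
    "    plus `BOUNDARY` and `$UNK`, in sorted order.\"\"\"\n",
    "    wc = Counter(w for ex in X for w in ex)\n",
    "    vocab = {w for w, c in wc.items() if c >= mincount}\n",
    "    vocab.update({BOUNDARY, '$UNK'})  # guarantee both special symbols\n",
    "    return sorted(vocab)\n",
    "```"
   ]
  },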
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_simple_chained_rnn_with_hyperparameter_search(X, y):\n",
    "    vocab = utils.get_vocab(X, mincount=2)\n",
    "\n",
    "    mod = TorchRNNClassifier(\n",
    "        vocab,\n",
    "        hidden_dim=300,\n",
    "        embed_dim=300,\n",
    "        bidirectional=True,\n",
    "        early_stopping=True,\n",
    "        max_iter=1)\n",
    "\n",
    "    param_grid = {\n",
    "        'batch_size': [32, 64, 128, 256],\n",
    "        'eta': [0.0001, 0.001, 0.01]}\n",
    "\n",
    "    bestmod = utils.fit_classifier_with_hyperparameter_search(\n",
    "        X, y, mod, cv=3, param_grid=param_grid)\n",
    "\n",
    "    return bestmod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Finished epoch 1 of 1; error is 4347.5091073811054"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best params: {'batch_size': 64, 'eta': 0.001}\n",
      "Best score: 0.670\n",
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.658     0.705     0.681     54982\n",
      "   entailment      0.697     0.697     0.697     54867\n",
      "      neutral      0.673     0.626     0.649     54962\n",
      "\n",
      "     accuracy                          0.676    164811\n",
      "    macro avg      0.676     0.676     0.675    164811\n",
      " weighted avg      0.676     0.676     0.675    164811\n",
      "\n",
      "CPU times: user 33min 4s, sys: 9.37 s, total: 33min 13s\n",
      "Wall time: 33min 6s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "chained_rnn_experiment_xval = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=simple_chained_rep_rnn_phi,\n",
    "    train_func=fit_simple_chained_rnn_with_hyperparameter_search,\n",
    "    assess_reader=None,\n",
    "    vectorize=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_chained_rnn = chained_rnn_experiment_xval['model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "del chained_rnn_experiment_xval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_optimized_simple_chained_rnn(X, y):\n",
    "    optimized_chained_rnn.max_iter = 1000\n",
    "    optimized_chained_rnn.fit(X, y)\n",
    "    return optimized_chained_rnn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 15. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 1677.3928870372474"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.766     0.733     0.749      3278\n",
      "   entailment      0.729     0.808     0.766      3329\n",
      "      neutral      0.727     0.677     0.701      3235\n",
      "\n",
      "     accuracy                          0.740      9842\n",
      "    macro avg      0.740     0.739     0.739      9842\n",
      " weighted avg      0.740     0.740     0.739      9842\n",
      "\n",
      "CPU times: user 22min 14s, sys: 8.09 s, total: 22min 22s\n",
      "Wall time: 22min 18s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "_ = nli.experiment(\n",
    "    train_reader=nli.SNLITrainReader(SNLI_HOME),\n",
    "    phi=simple_chained_rep_rnn_phi,\n",
    "    train_func=fit_optimized_simple_chained_rnn,\n",
    "    assess_reader=nli.SNLIDevReader(SNLI_HOME),\n",
    "    vectorize=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model is close to the word cross-product baseline ($\\approx$0.75), but it does not beat it. Perhaps initializing it with GloVe embeddings would be enough to push it into the lead.\n",
    "\n",
    "The hypothesis-only baseline for this model is very simple: we just use the same model, but we process only the hypothesis."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Separate premise and hypothesis RNNs\n",
    "\n",
    "A natural variation on the above is to give the premise and hypothesis each their own RNN:\n",
    "\n",
    "<img src=\"fig/nli-rnn-chained-separate.png\" width=800 />\n",
    "\n",
    "This greatly increases the number of parameters, but it gives the model more opportunity to learn that a word appearing in the premise plays a different role from the same word appearing in the hypothesis. One could push this idea even further by giving the premise and hypothesis their own embeddings as well. This could take the form of a simple modification to [the sentence-encoder version defined above](#Sentence-encoding-RNNs)."
   ]
  },
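  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to realize this design in PyTorch is sketched below. It is not the course's implementation, and it rests on three assumptions:\n",
    "\n",
    "1. The hypothesis LSTM is initialized with the premise LSTM's final hidden and cell states, which is what keeps the model chained.\n",
    "1. The two RNNs share a single embedding.\n",
    "1. Batches are padded to a common length, with no packing or masking.\n",
    "\n",
    "The class name `SeparateChainedRNN` is hypothetical:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class SeparateChainedRNN(nn.Module):\n",
    "    \"\"\"Hypothetical model: premise and hypothesis get their own LSTMs,\n",
    "    but share one embedding layer.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embed_dim, hidden_dim, n_classes):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_dim)\n",
    "        self.premise_rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)\n",
    "        self.hypothesis_rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)\n",
    "        self.classifier = nn.Linear(hidden_dim, n_classes)\n",
    "\n",
    "    def forward(self, premise_ids, hypothesis_ids):\n",
    "        # Encode the premise and keep its final (h, c) state.\n",
    "        _, (h, c) = self.premise_rnn(self.embedding(premise_ids))\n",
    "        # Run the hypothesis RNN starting from the premise's final state.\n",
    "        _, (h, _) = self.hypothesis_rnn(self.embedding(hypothesis_ids), (h, c))\n",
    "        # h is (num_layers, batch, hidden_dim); classify from the top layer.\n",
    "        return self.classifier(h[-1])\n",
    "\n",
    "\n",
    "model = SeparateChainedRNN(vocab_size=20, embed_dim=8, hidden_dim=16, n_classes=3)\n",
    "premises = torch.randint(0, 20, (4, 5))    # 4 premises of length 5\n",
    "hypotheses = torch.randint(0, 20, (4, 7))  # 4 hypotheses of length 7\n",
    "logits = model(premises, hypotheses)       # shape (4, 3)\n",
    "```\n",
    "\n",
    "Replacing the shared `nn.Embedding` with one embedding per RNN would implement the further variant mentioned above."
   ]
  },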
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention mechanisms\n",
    "\n",
    "Many of the best-performing systems on [the SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) use __attention mechanisms__ to help the model learn important associations between words in the premise and words in the hypothesis. I believe [Rocktäschel et al. (2015)](https://arxiv.org/pdf/1509.06664v1.pdf) were the first to explore such models for NLI.\n",
    "\n",
    "For instance, if _puppy_ appears in the premise and _dog_ in the hypothesis, then that might be a high-precision indicator that the correct relationship is entailment.\n",
    "\n",
    "This diagram is a high-level schematic for adding attention mechanisms to a chained RNN model for NLI:\n",
    "\n",
    "<img src=\"fig/nli-rnn-attention.png\" width=800 />\n",
    "\n",
    "Since PyTorch will handle the details of backpropagation, implementing these models is largely reduced to figuring out how to wrangle the states of the model in the desired way."
   ]
  },
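  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a concrete example of this kind of state wrangling, here is a minimal sketch of one simple variant: global dot-product attention, in which the final hypothesis state scores every premise state. It is an illustration under assumed tensor shapes. It is not the specific mechanism of Rocktäschel et al. (2015), which uses a learned scoring function. It also omits the masking of padded premise positions that a real batched implementation would need. The name `attend` is hypothetical:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def attend(premise_states, hypothesis_final):\n",
    "    \"\"\"Hypothetical helper: dot-product attention over premise states.\n",
    "\n",
    "    premise_states: (batch, premise_len, hidden)\n",
    "    hypothesis_final: (batch, hidden)\n",
    "    Returns the context vector (batch, hidden) and the attention\n",
    "    weights (batch, premise_len).\n",
    "    \"\"\"\n",
    "    # One score per premise position: its dot product with the hypothesis state.\n",
    "    scores = torch.bmm(premise_states, hypothesis_final.unsqueeze(2)).squeeze(2)\n",
    "    # Normalize the scores into a distribution over premise positions.\n",
    "    weights = torch.softmax(scores, dim=1)\n",
    "    # Context vector: the attention-weighted average of the premise states.\n",
    "    context = torch.bmm(weights.unsqueeze(1), premise_states).squeeze(1)\n",
    "    return context, weights\n",
    "\n",
    "\n",
    "premise_states = torch.randn(2, 6, 10)  # e.g., all outputs of a premise RNN\n",
    "hypothesis_final = torch.randn(2, 10)   # e.g., final state of a hypothesis RNN\n",
    "context, weights = attend(premise_states, hypothesis_final)\n",
    "# One simple option: feed [context; hypothesis_final] to the classifier layer.\n",
    "combined = torch.cat([context, hypothesis_final], dim=1)  # shape (2, 20)\n",
    "```"
   ]
  },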
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Error analysis with the MultiNLI annotations\n",
    "\n",
    "The annotations included with the MultiNLI corpus create some powerful yet easy opportunities for error analysis right out of the box. This section illustrates how to make use of them with models you've trained.\n",
    "\n",
    "First, we train a chained RNN model on a sample of the MultiNLI data, just for illustrative purposes. To save time, we'll carry over the optimal model we used above for SNLI. (For a real experiment, of course, we would want to conduct the hyperparameter search again, since MultiNLI is very different from SNLI.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 13. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 809.0821097567677"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.687     0.555     0.614     39052\n",
      "   entailment      0.515     0.648     0.574     39245\n",
      "      neutral      0.539     0.504     0.521     39514\n",
      "\n",
      "     accuracy                          0.569    117811\n",
      "    macro avg      0.580     0.569     0.570    117811\n",
      " weighted avg      0.580     0.569     0.569    117811\n",
      "\n"
     ]
    }
   ],
   "source": [
    "rnn_multinli_experiment = nli.experiment(\n",
    "    train_reader=nli.MultiNLITrainReader(MULTINLI_HOME),\n",
    "    phi=simple_chained_rep_rnn_phi,\n",
    "    train_func=fit_optimized_simple_chained_rnn,\n",
    "    assess_reader=None,\n",
    "    random_state=42,\n",
    "    vectorize=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The return value of `nli.experiment` contains the information we need to make predictions on new examples. \n",
    "\n",
    "Next, we load in the 'matched' condition annotations ('mismatched' would work as well):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "matched_ann_filename = os.path.join(\n",
    "    ANNOTATIONS_HOME,\n",
    "    \"multinli_1.0_matched_annotations.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "matched_ann = nli.read_annotated_subset(\n",
    "    matched_ann_filename, MULTINLI_HOME)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following function uses `rnn_multinli_experiment` to make predictions on annotated examples, and harvests some other information that is useful for error analysis:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
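    "def predict_annotated_example(ann, experiment_results):\n",
    "    \"\"\"Predict the label for one annotated MultiNLI example.\n",
    "\n",
    "    Returns a dict mapping each of the example's annotation categories\n",
    "    to True, plus its 'gold' label, the model's 'prediction', and\n",
    "    whether the two match ('correct').\n",
    "    \"\"\"\n",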
    "    model = experiment_results['model']\n",
    "    phi = experiment_results['phi']\n",
    "    ex = ann['example']\n",
    "    prem = ex.sentence1_parse\n",
    "    hyp = ex.sentence2_parse\n",
    "    feats = phi(prem, hyp)\n",
    "    pred = model.predict([feats])[0]\n",
    "    gold = ex.gold_label\n",
    "    data = {cat: True for cat in ann['annotations']}\n",
    "    data.update({'gold': gold, 'prediction': pred, 'correct': gold == pred})\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, this function applies `predict_annotated_example` to a collection of annotated examples and puts the results in a `pd.DataFrame` for flexible analysis:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_predictions_for_annotated_data(anns, experiment_results):\n",
    "    data = []\n",
    "    for ex_id, ann in anns.items():\n",
    "        results = predict_annotated_example(ann, experiment_results)\n",
    "        data.append(results)\n",
    "    return pd.DataFrame(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "ann_analysis_df = get_predictions_for_annotated_data(\n",
    "    matched_ann, rnn_multinli_experiment)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `ann_analysis_df`, we can see how the model does on individual annotation categories:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>#MODAL</th>\n",
       "      <th>True</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>correct</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>False</th>\n",
       "      <td>52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>True</th>\n",
       "      <td>92</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "#MODAL   True\n",
       "correct      \n",
       "False      52\n",
       "True       92"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.crosstab(ann_analysis_df['correct'], ann_analysis_df['#MODAL'])"
   ]
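  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same comparison can be run for every annotation category at once. Here is a minimal sketch under two assumptions:\n",
    "\n",
    "1. As `predict_annotated_example` implies, each category is a column that is `True` for annotated examples and missing (`NaN`) otherwise.\n",
    "1. All category names begin with `#`, as `#MODAL` does.\n",
    "\n",
    "The helper name `accuracy_by_category` is hypothetical:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def accuracy_by_category(df):\n",
    "    \"\"\"Hypothetical helper: example count and accuracy per annotation column.\"\"\"\n",
    "    rows = []\n",
    "    for cat in df.columns:\n",
    "        if not cat.startswith('#'):\n",
    "            continue  # skip 'gold', 'prediction', and 'correct'\n",
    "        subset = df[df[cat].eq(True)]  # rows annotated with this category\n",
    "        rows.append({'category': cat,\n",
    "                     'n': len(subset),\n",
    "                     'accuracy': subset['correct'].mean()})\n",
    "    # Lowest accuracy first, so the hardest categories are at the top.\n",
    "    return pd.DataFrame(rows).sort_values('accuracy')\n",
    "```\n",
    "\n",
    "Calling `accuracy_by_category(ann_analysis_df)` would then list the categories from hardest to easiest for this model. Categories with a small `n` call for caution, since their accuracies rest on few examples."
   ]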
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
